package v1

import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"gitee.com/zackeus/go-boot/scs/v1/store"
	"gitee.com/zackeus/goutil"
	"gitee.com/zackeus/goutil/strutil"
	"net/http"
	"sort"
	"sync"
	"sync/atomic"
	"time"
)

type (
	contextKey string
	// Status session 状态值
	Status int

	sessionData struct {
		uid      string
		deadline time.Time
		status   Status
		token    string
		values   map[string]interface{}
		mu       sync.Mutex
	}

	// ContextHandle context 自定义处理
	ContextHandle func(ctx context.Context, uid string, token string) context.Context
	// UnauthorizedCallback 未鉴权回调
	UnauthorizedCallback func(http.ResponseWriter, *http.Request, http.Handler)
	// ErrorCallback 异常回调函数
	ErrorCallback func(http.ResponseWriter, *http.Request, error)
	// IterateCallback 迭代回调
	IterateCallback func(uid string, token string, deadline time.Time, values map[string]any) error
)

const (
	// Unmodified indicates that the session data hasn't been changed in the
	// current request cycle.
	Unmodified Status = iota

	// Modified indicates that the session data has been changed in the current
	// request cycle.
	Modified

	// Destroyed indicates that the session data has been destroyed in the
	// current request cycle.
	Destroyed
)

var (
	contextKeyID      uint64
	contextKeyIDMutex = &sync.Mutex{}
)

// sessionData 生成器
func newSessionData(uid string, lifetime time.Duration) *sessionData {
	return &sessionData{
		uid:      uid,
		deadline: time.Now().Add(lifetime).UTC(),
		status:   Unmodified,
		values:   make(map[string]interface{}),
	}
}

// contextKey 生成器
func generateContextKey() contextKey {
	contextKeyIDMutex.Lock()
	defer contextKeyIDMutex.Unlock()
	atomic.AddUint64(&contextKeyID, 1)
	return contextKey(fmt.Sprintf("session.%d", contextKeyID))
}

// token 生成器
func generateToken() (string, error) {
	b := make([]byte, 32)
	_, err := rand.Read(b)
	if err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

// 获取 sessionData
func (s *SessionManager) getSessionDataFromContext(ctx context.Context) (*sessionData, error) {
	c, ok := ctx.Value(s.contextKey).(*sessionData)
	if !ok {
		return nil, NoSessionData
	}
	return c, nil
}

// 绑定 sessionData
func (s *SessionManager) addSessionDataToContext(ctx context.Context, sd *sessionData) context.Context {
	return context.WithValue(ctx, s.contextKey, sd)
}

// session store commit
func (s *SessionManager) doStoreCommit(ctx context.Context, token string, b []byte, expiry time.Time) (err error) {
	c, ok := s.store.(interface {
		CommitCtx(context.Context, string, string, []byte, time.Time) error
	})
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return err
	}

	if ok {
		return c.CommitCtx(ctx, sd.uid, token, b, expiry)
	}
	return s.store.Commit(sd.uid, token, b, expiry)
}

// session store Get
func (s *SessionManager) doStoreGet(ctx context.Context, token string) (b []byte, found bool, err error) {
	c, ok := s.store.(interface {
		GetCtx(context.Context, string) ([]byte, bool, error)
	})
	if ok {
		return c.GetCtx(ctx, token)
	}
	return s.store.Get(token)
}

func (s *SessionManager) doStoreFind(ctx context.Context, uid string) (map[string][]byte, error) {
	c, ok := s.store.(interface {
		FindCtx(context.Context, string) (map[string][]byte, error)
	})
	if ok {
		return c.FindCtx(ctx, uid)
	}
	return s.store.Find(uid)
}

func (s *SessionManager) doStoreAll(ctx context.Context) (map[string][]byte, error) {
	cs, ok := s.store.(store.IterableCtxStore)
	if ok {
		return cs.AllCtx(ctx)
	}

	is, ok := s.store.(store.IterableStore)
	if ok {
		return is.All()
	}

	return map[string][]byte{}, errors.New(fmt.Sprintf("type %T does not support iteration", s.store))
}

// session store delete
func (s *SessionManager) doStoreDelete(ctx context.Context, uid string, tokens ...string) (err error) {
	c, ok := s.store.(interface {
		DeleteCtx(context.Context, string, []string) error
	})

	if ok {
		return c.DeleteCtx(ctx, uid, tokens)
	}
	return s.store.Delete(uid, tokens)
}

// session store clear
func (s *SessionManager) doStoreClear(ctx context.Context, uids []string) (err error) {
	c, ok := s.store.(interface {
		ClearCtx(context.Context, []string) error
	})

	if ok {
		return c.ClearCtx(ctx, uids)
	}
	return s.store.Clear(uids)
}

// Status 获取 session 状态
func (s *SessionManager) Status(ctx context.Context) (Status, error) {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return Destroyed, err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	return sd.status, nil
}

// Uid 返回唯一标识。请注意，如果在会话提交到存储之前调用它，它将返回空字符串 ""
func (s *SessionManager) Uid(ctx context.Context) (string, error) {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return "", err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	return sd.uid, nil
}

// Token 返回会话令牌。请注意，如果在会话提交到存储之前调用它，它将返回空字符串 ""
func (s *SessionManager) Token(ctx context.Context) (string, error) {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return "", err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	return sd.token, nil
}

// RememberMe 控制会话 cookie 是否持久（即是否在用户关闭浏览器后保留）
// RememberMe 仅设置 SessionManager.Cookie.Persist = false（默认为 true）并且调用 WriteSessionCookie 时有效
func (s *SessionManager) RememberMe(ctx context.Context, val bool) error {
	return s.Put(ctx, "__rememberMe", val)
}

// Iterate 从存储中检索所有活动的（即未过期的）会话并为每个会话执行提供的函数 call
// uid: 可选项 指定 uid 索引
func (s *SessionManager) Iterate(ctx context.Context, call IterateCallback, uid ...string) error {
	var (
		allSessions map[string][]byte
		err         error
	)

	if !goutil.IsEmpty(uid) && strutil.IsNotBlank(uid[0]) {
		allSessions, err = s.doStoreFind(ctx, uid[0])
	} else {
		allSessions, err = s.doStoreAll(ctx)
	}
	if err != nil {
		return err
	}

	for token, b := range allSessions {
		uid, deadline, values, err := s.codec.Decode(b)
		if err != nil {
			return err
		}
		err = call(uid, token, deadline, values)
		if err != nil {
			return err
		}
	}
	return nil
}

// Deadline 返回session的绝对到期时间。请注意，如果使用 IdleTimeout，则会话可能会由于在返回的截止日期之前未使用而过期
func (s *SessionManager) Deadline(ctx context.Context) (time.Time, error) {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return time.Time{}, err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	return sd.deadline, nil
}

// ClearData 删除当前会话的所有数据。会话令牌和生命周期不受影响。如果当前会话中没有数据，则这是一个空操作
func (s *SessionManager) ClearData(ctx context.Context) error {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	if len(sd.values) == 0 {
		return nil
	}

	for key := range sd.values {
		delete(sd.values, key)
	}
	sd.status = Modified
	return nil
}

// Exists 如果会话数据中存在给定键，则返回 true
func (s *SessionManager) Exists(ctx context.Context, key string) (bool, error) {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return false, err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()
	_, exists := sd.values[key]

	return exists, nil
}

// Keys 返回会话数据中存在的所有键名称的一部分，按字母顺序排序。如果数据不包含任何数据，则将返回一个空切片
func (s *SessionManager) Keys(ctx context.Context) ([]string, error) {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return []string{}, err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	keys := make([]string, len(sd.values))
	i := 0
	for key := range sd.values {
		keys[i] = key
		i++
	}
	sort.Strings(keys)
	return keys, nil
}

// Put 向会话数据添加键和相应的值。键的任何现有值都将被替换。会话数据状态将设置为 Modified
func (s *SessionManager) Put(ctx context.Context, key string, val interface{}) error {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	sd.values[key] = val
	sd.status = Modified

	return nil
}

// Get 从会话数据返回给定键的值
// 返回值的类型为 interface{} 因此通常需要在使用它之前进行类型断言
func (s *SessionManager) Get(ctx context.Context, key string) (interface{}, error) {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return nil, err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	return sd.values[key], nil
}

// Pop 就像一次性 Get 一样。返回会话数据中给定键的值，并从会话数据中删除键和值
// 会话数据状态将设置为已修改。返回值的类型为 interface{} 因此通常需要在使用它之前进行类型断言
func (s *SessionManager) Pop(ctx context.Context, key string) (interface{}, error) {
	sd, err := s.getSessionDataFromContext(ctx)
	if goutil.IsEqual(NoSessionData, err) {
		/* sessionData 不存在 直接返回 */
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	val, exists := sd.values[key]
	if !exists {
		return nil, nil
	}
	delete(sd.values, key)
	sd.status = Modified

	return val, nil
}

// GetBool 从会话数据返回给定键的 bool 值。如果键不存在或无法将值类型断言为 bool，则返回 bool 的零值 (false)。
func (s *SessionManager) GetBool(ctx context.Context, key string) (bool, error) {
	val, err := s.Get(ctx, key)
	if err != nil {
		return false, err
	}

	b, ok := val.(bool)
	if !ok {
		return false, nil
	}
	return b, nil
}

// GetString 从会话数据返回给定键的字符串值。如果键不存在或无法将值类型断言为字符串，则返回字符串 ("") 的零值
func (s *SessionManager) GetString(ctx context.Context, key string) (string, error) {
	val, err := s.Get(ctx, key)
	if err != nil {
		return "", err
	}

	str, ok := val.(string)
	if !ok {
		return "", nil
	}
	return str, nil
}

// GetInt 从会话数据返回给定键的 int 值。如果键不存在或无法将值类型断言为 int，则返回 int 的零值 (0)。
func (s *SessionManager) GetInt(ctx context.Context, key string) (int, error) {
	val, err := s.Get(ctx, key)
	if err != nil {
		return 0, err
	}

	i, ok := val.(int)
	if !ok {
		return 0, nil
	}
	return i, nil
}

// GetInt64 从会话数据中返回给定键的 int64 值。如果键不存在或无法将值类型断言为 int64，则返回 int64 的零值 (0)
func (s *SessionManager) GetInt64(ctx context.Context, key string) (int64, error) {
	val, err := s.Get(ctx, key)
	if err != nil {
		return 0, err
	}

	i, ok := val.(int64)
	if !ok {
		return 0, nil
	}
	return i, nil
}

// GetInt32 从会话数据返回给定键的 int 值。如果键不存在或无法将值类型断言为 int32，则返回 int32 的零值 (0)
func (s *SessionManager) GetInt32(ctx context.Context, key string) (int32, error) {
	val, err := s.Get(ctx, key)
	if err != nil {
		return 0, err
	}

	i, ok := val.(int32)
	if !ok {
		return 0, nil
	}
	return i, nil
}

// GetFloat 从会话数据中返回给定键的 float64 值。如果键不存在或无法将值类型断言为 float64，则返回 float64 的零值 (0)
func (s *SessionManager) GetFloat(ctx context.Context, key string) (float64, error) {
	val, err := s.Get(ctx, key)
	if err != nil {
		return 0, err
	}

	f, ok := val.(float64)
	if !ok {
		return 0, nil
	}
	return f, nil
}

// GetBytes 返回会话数据中给定键的字节切片 ([]byte) 值。如果键不存在或无法将类型断言为 []byte，则返回切片的零值 (nil)
func (s *SessionManager) GetBytes(ctx context.Context, key string) ([]byte, error) {
	val, err := s.Get(ctx, key)
	if err != nil {
		return nil, err
	}

	b, ok := val.([]byte)
	if !ok {
		return nil, nil
	}
	return b, nil
}

// GetTime 从会话数据中返回给定键的 time.Time 值。
// 如果键不存在或无法将值类型声明为 time.Time，则返回 time.Time 对象的零值。这可以使用 time.IsZero() 方法进行测试
func (s *SessionManager) GetTime(ctx context.Context, key string) (time.Time, error) {
	val, err := s.Get(ctx, key)
	if err != nil {
		return time.Time{}, err
	}

	t, ok := val.(time.Time)
	if !ok {
		return time.Time{}, nil
	}
	return t, nil
}

// PopString 返回给定键的字符串值，然后将其从会话数据中删除。会话数据状态将设置为已修改。
// 如果键不存在或无法将值类型断言为字符串，则返回字符串 ("") 的零值
func (s *SessionManager) PopString(ctx context.Context, key string) (string, error) {
	val, err := s.Pop(ctx, key)
	if err != nil {
		return "", err
	}

	str, ok := val.(string)
	if !ok {
		return "", nil
	}
	return str, nil
}

// PopBool 返回给定键的 bool 值，然后将其从会话数据中删除。会话数据状态将设置为已修改。
// 如果键不存在或无法将值类型断言为 bool，则返回 bool 的零值 (false)
func (s *SessionManager) PopBool(ctx context.Context, key string) (bool, error) {
	val, err := s.Pop(ctx, key)
	if err != nil {
		return false, err
	}

	b, ok := val.(bool)
	if !ok {
		return false, nil
	}
	return b, nil
}

// PopInt 返回给定键的 int 值，然后将其从会话数据中删除。会话数据状态将设置为已修改。
// 如果键不存在或无法将值类型断言为 int，则返回 int 的零值 (0)
func (s *SessionManager) PopInt(ctx context.Context, key string) (int, error) {
	val, err := s.Pop(ctx, key)
	if err != nil {
		return 0, err
	}

	i, ok := val.(int)
	if !ok {
		return 0, nil
	}
	return i, nil
}

// PopFloat 返回给定键的 float64 值，然后将其从会话数据中删除。
// 会话数据状态将设置为已修改。如果键不存在或无法将值类型断言为 float64，则返回 float64 的零值 (0)
func (s *SessionManager) PopFloat(ctx context.Context, key string) (float64, error) {
	val, err := s.Pop(ctx, key)
	if err != nil {
		return 0, err
	}

	f, ok := val.(float64)
	if !ok {
		return 0, nil
	}
	return f, nil
}

// PopBytes 返回给定键的字节切片 ([]byte) 值，然后将其从会话数据中删除。
// 会话数据状态将设置为已修改。如果键不存在或无法将类型断言为 []byte，则返回切片的零值 (nil)
func (s *SessionManager) PopBytes(ctx context.Context, key string) ([]byte, error) {
	val, err := s.Pop(ctx, key)
	if err != nil {
		return nil, err
	}

	b, ok := val.([]byte)
	if !ok {
		return nil, nil
	}
	return b, nil
}

// PopTime 返回给定键的 time.Time 值，然后将其从会话数据中删除。会话数据状态将设置为已修改。
// 如果键不存在或无法将值类型声明为 time.Time，则返回 time.Time 对象的零值
func (s *SessionManager) PopTime(ctx context.Context, key string) (time.Time, error) {
	val, err := s.Pop(ctx, key)
	if err != nil {
		return time.Time{}, err
	}

	t, ok := val.(time.Time)
	if !ok {
		return time.Time{}, nil
	}
	return t, nil
}
